{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddleslim.nas import SANAS\n",
    "import numpy as np\n",
    "\n",
    "BATCH_SIZE=96\n",
    "SERVER_ADDRESS = \"\"\n",
    "PORT = 8377\n",
    "SEARCH_STEPS = 300\n",
    "RETAIN_EPOCH=30\n",
    "MAX_PARAMS=3.77\n",
    "IMAGE_SHAPE=[3, 32, 32]\n",
    "AUXILIARY = True\n",
    "AUXILIARY_WEIGHT= 0.4\n",
    "TRAINSET_NUM = 50000\n",
    "LR = 0.025\n",
    "MOMENTUM = 0.9\n",
    "WEIGHT_DECAY = 0.0003\n",
    "DROP_PATH_PROBILITY = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-23 12:28:09,752-INFO: range table: ([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14])\n",
      "2020-02-23 12:28:09,754-INFO: ControllerServer - listen on: [127.0.0.1:8377]\n",
      "2020-02-23 12:28:09,756-INFO: Controller Server run...\n"
     ]
    }
   ],
   "source": [
    "config = [('DartsSpace')]\n",
    "sa_nas = SANAS(config, server_addr=(SERVER_ADDRESS, PORT), search_steps=SEARCH_STEPS, is_server=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters_in_MB(all_params, prefix='model'):\n",
    "    parameters_number = 0\n",
    "    for param in all_params:\n",
    "        if param.name.startswith(\n",
    "                prefix) and param.trainable and 'aux' not in param.name:\n",
    "            parameters_number += np.prod(param.shape)\n",
    "    return parameters_number / 1e6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_data_loader(IMAGE_SHAPE, is_train):\n",
    "    image = fluid.data(\n",
    "        name=\"image\", shape=[None] + IMAGE_SHAPE, dtype=\"float32\")\n",
    "    label = fluid.data(name=\"label\", shape=[None, 1], dtype=\"int64\")\n",
    "    data_loader = fluid.io.DataLoader.from_generator(\n",
    "        feed_list=[image, label],\n",
    "        capacity=64,\n",
    "        use_double_buffer=True,\n",
    "        iterable=True)\n",
    "    drop_path_prob = ''\n",
    "    drop_path_mask = ''\n",
    "    if is_train:\n",
    "        drop_path_prob = fluid.data(\n",
    "            name=\"drop_path_prob\", shape=[BATCH_SIZE, 1], dtype=\"float32\")\n",
    "        drop_path_mask = fluid.data(\n",
    "            name=\"drop_path_mask\",\n",
    "            shape=[BATCH_SIZE, 20, 4, 2],\n",
    "            dtype=\"float32\")\n",
    "\n",
    "    return data_loader, image, label, drop_path_prob, drop_path_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_program(main_program, startup_program, IMAGE_SHAPE, archs, is_train):\n",
    "    with fluid.program_guard(main_program, startup_program):\n",
    "        data_loader, data, label, drop_path_prob, drop_path_mask = create_data_loader(\n",
    "            IMAGE_SHAPE, is_train)\n",
    "        logits, logits_aux = archs(data, drop_path_prob, drop_path_mask,\n",
    "                                   is_train, 10)\n",
    "        top1 = fluid.layers.accuracy(input=logits, label=label, k=1)\n",
    "        top5 = fluid.layers.accuracy(input=logits, label=label, k=5)\n",
    "        loss = fluid.layers.reduce_mean(\n",
    "            fluid.layers.softmax_with_cross_entropy(logits, label))\n",
    "\n",
    "        if is_train:\n",
    "            if AUXILIARY:\n",
    "                loss_aux = fluid.layers.reduce_mean(\n",
    "                    fluid.layers.softmax_with_cross_entropy(logits_aux, label))\n",
    "                loss = loss + AUXILIARY_WEIGHT * loss_aux\n",
    "            step_per_epoch = int(TRAINSET_NUM / BATCH_SIZE)\n",
    "            learning_rate = fluid.layers.cosine_decay(LR, step_per_epoch, RETAIN_EPOCH)\n",
    "            fluid.clip.set_gradient_clip(\n",
    "                clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))\n",
    "            optimizer = fluid.optimizer.MomentumOptimizer(\n",
    "                learning_rate,\n",
    "                MOMENTUM,\n",
    "                regularization=fluid.regularizer.L2DecayRegularizer(\n",
    "                    WEIGHT_DECAY))\n",
    "            optimizer.minimize(loss)\n",
    "            outs = [loss, top1, top5, learning_rate]\n",
    "        else:\n",
    "            outs = [loss, top1, top5]\n",
    "    return outs, data_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(main_prog, exe, epoch_id, train_loader, fetch_list):\n",
    "    loss = []\n",
    "    top1 = []\n",
    "    top5 = []\n",
    "    for step_id, data in enumerate(train_loader()):\n",
    "        devices_num = len(data)\n",
    "        if DROP_PATH_PROBILITY > 0:\n",
    "            feed = []\n",
    "            for device_id in range(devices_num):\n",
    "                image = data[device_id]['image']\n",
    "                label = data[device_id]['label']\n",
    "                drop_path_prob = np.array(\n",
    "                    [[DROP_PATH_PROBILITY * epoch_id / RETAIN_EPOCH]\n",
    "                     for i in range(BATCH_SIZE)]).astype(np.float32)\n",
    "                drop_path_mask = 1 - np.random.binomial(\n",
    "                    1, drop_path_prob[0],\n",
    "                    size=[BATCH_SIZE, 20, 4, 2]).astype(np.float32)\n",
    "                feed.append({\n",
    "                    \"image\": image,\n",
    "                    \"label\": label,\n",
    "                    \"drop_path_prob\": drop_path_prob,\n",
    "                    \"drop_path_mask\": drop_path_mask\n",
    "                })\n",
    "        else:\n",
    "            feed = data\n",
    "        loss_v, top1_v, top5_v, lr = exe.run(\n",
    "            main_prog, feed=feed, fetch_list=[v.name for v in fetch_list])\n",
    "        loss.append(loss_v)\n",
    "        top1.append(top1_v)\n",
    "        top5.append(top5_v)\n",
    "        if step_id % 10 == 0:\n",
    "            print(\n",
    "                \"Train Epoch {}, Step {}, Lr {:.8f}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}\".\n",
    "                format(epoch_id, step_id, lr[0], np.mean(loss), np.mean(top1), np.mean(top5)))\n",
    "    return np.mean(top1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid(main_prog, exe, epoch_id, valid_loader, fetch_list):\n",
    "    loss = []\n",
    "    top1 = []\n",
    "    top5 = []\n",
    "    for step_id, data in enumerate(valid_loader()):\n",
    "        loss_v, top1_v, top5_v = exe.run(\n",
    "            main_prog, feed=data, fetch_list=[v.name for v in fetch_list])\n",
    "        loss.append(loss_v)\n",
    "        top1.append(top1_v)\n",
    "        top5.append(top5_v)\n",
    "        if step_id % 10 == 0:\n",
    "            print(\n",
    "                \"Valid Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}\".\n",
    "                format(epoch_id, step_id, np.mean(loss), np.mean(top1), np.mean(top5)))\n",
    "    return np.mean(top1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2020-02-23 12:28:57,462-INFO: current tokens: [5, 5, 5, 5, 5, 12, 7, 7, 7, 7, 7, 7, 7, 10, 10, 10, 10, 10, 10, 10]\n"
     ]
    }
   ],
   "source": [
    "archs = sa_nas.next_archs()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_program = fluid.Program()\n",
    "test_program = fluid.Program()\n",
    "startup_program = fluid.Program()\n",
    "train_fetch_list, train_loader = build_program(train_program, startup_program, IMAGE_SHAPE, archs, is_train=True)\n",
    "test_fetch_list, test_loader = build_program(test_program, startup_program, IMAGE_SHAPE, archs, is_train=False)\n",
    "test_program = test_program.clone(for_test=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "place = fluid.CPUPlace()\n",
    "exe = fluid.Executor(place)\n",
    "exe.run(startup_program)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<paddle.fluid.reader.GeneratorLoader at 0x7fddc8fe7cd0>"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_reader = paddle.fluid.io.batch(paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024), batch_size=BATCH_SIZE, drop_last=True)\n",
    "test_reader = paddle.fluid.io.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=BATCH_SIZE, drop_last=False)\n",
    "train_loader.set_sample_list_generator(train_reader, places=place)\n",
    "test_loader.set_sample_list_generator(test_reader, places=place)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch 0, Step 0, Lr 0.02500000, loss 3.310467, acc_1 0.062500, acc_5 0.468750\n"
     ]
    }
   ],
   "source": [
    "for epoch_id in range(RETAIN_EPOCH):\n",
    "    train_top1 = train(train_program, exe, epoch_id, train_loader, train_fetch_list)\n",
    "    print(\"TRAIN: Epoch {}, train_acc {:.6f}\".format(epoch_id, train_top1))\n",
    "    valid_top1 = valid(test_program, exe, epoch_id, test_loader, test_fetch_list)\n",
    "    print(\"TEST: Epoch {}, valid_acc {:.6f}\".format(epoch_id, valid_top1))\n",
    "    valid_top1_list.append(valid_top1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
